"""This module contains vulnerability types, Enums, nodes and helpers."""

import os
from collections import namedtuple
from enum import Enum

from lib.pyt.core.node_types import YieldNode
from lib.pyt.vulnerabilities.trigger_definitions_parser import parse_rules

# Path to the bundled taint configuration, resolved relative to this module's
# own location (via __file__) rather than the current working directory.
default_taint_config_file = os.path.join(
    os.path.dirname(__file__), "..", "vulnerability_definitions", "taint.config"
)
# Lookup table keyed by "<source>:<sink>" -> the parsed rule declaring that pair.
# Vulnerability.as_dict() queries it with "<source_type>:<sink_type>".
# NOTE: this is populated at import time, so importing the module parses the file.
source_sink_rules = {}
all_rules = parse_rules(default_taint_config_file)
for rule in all_rules:
    # Index every (source, sink) combination a rule covers. If two rules share
    # a pair, the one appearing later in parse_rules' output overwrites it.
    for source in rule.sources:
        for sink in rule.sinks:
            source_sink_rules[f"{source}:{sink}"] = rule


class VulnerabilityType(Enum):
    """Classification of a source-to-sink taint flow.

    ``vuln_factory`` maps ``UNKNOWN`` and ``SANITISED`` to their dedicated
    Vulnerability subclasses; every other member uses the base class.
    Member order and values are significant (iteration order / persisted ints).
    """

    FALSE = 0  # flow classified as not vulnerable
    SANITISED = 1  # flow passes through a sanitiser
    TRUE = 2  # flow classified as vulnerable
    UNKNOWN = 3  # classification could not be decided


def vuln_factory(vulnerability_type):
    """Pick the Vulnerability class to instantiate for ``vulnerability_type``.

    ``VulnerabilityType.UNKNOWN`` -> ``UnknownVulnerability``,
    ``VulnerabilityType.SANITISED`` -> ``SanitisedVulnerability``,
    anything else (including ``TRUE`` and ``FALSE``) -> ``Vulnerability``.
    """
    return (
        UnknownVulnerability
        if vulnerability_type == VulnerabilityType.UNKNOWN
        else SanitisedVulnerability
        if vulnerability_type == VulnerabilityType.SANITISED
        else Vulnerability
    )


def _get_reassignment_str(reassignment_nodes):
    reassignments = ""
    if reassignment_nodes:
        reassignments += "\nReassigned in:\n\t"
        reassignments += "\n\t".join(
            [
                "File: "
                + node.path
                + "\n"
                + "\t > Line "
                + str(node.line_number)
                + ": "
                + node.label
                for node in reassignment_nodes
            ]
        )
    return reassignments


class Vulnerability:
    """A taint flow from a source node to a sink node.

    Holds the source and sink nodes, the trigger words and types that matched
    them, and the intermediate reassignment nodes the taint passed through.
    """

    def __init__(
        self,
        source,
        source_trigger_word,
        source_type,
        sink,
        sink_trigger_word,
        sink_type,
        sink_args,
        reassignment_nodes,
    ):
        """Set source and sink information.

        Args:
            source: node where the tainted data originates; must expose
                ``path``, ``line_number``, ``label`` and ``as_dict()``.
            source_trigger_word: trigger word that matched ``source``.
            source_type: source category; also used as the rule-lookup key.
            sink: node the tainted data reaches; same interface as ``source``.
            sink_trigger_word: trigger word that matched ``sink``.
            sink_type: sink category; also used as the rule-lookup key.
            sink_args: stored as ``self.sink_args``; not otherwise used here.
            reassignment_nodes: iterable of nodes between source and sink.
                It is copied, so the caller's list is never mutated. ``None``
                is treated as empty.
        """
        self.source = source
        self.source_type = source_type
        self.source_trigger_word = source_trigger_word
        self.sink = sink
        self.sink_type = sink_type
        self.sink_trigger_word = sink_trigger_word
        # Previously accepted but silently discarded; keep it for inspection.
        self.sink_args = sink_args
        # Work on a private copy so filtering never touches the caller's list.
        self.reassignment_nodes = list(reassignment_nodes or [])
        self._remove_non_propagating_yields()

    def _remove_non_propagating_yields(self):
        """Remove yield with no variables e.g. `yield 123` and plain `yield` from vulnerability.

        Rebinds ``self.reassignment_nodes`` to a filtered list (order kept)
        instead of calling ``list.remove`` per node. ``list.remove`` is
        equality-based and quadratic, and it mutated a list shared with the caller.
        """
        # NOTE(review): presumably YieldNode.right_hand_side_variables always
        # contains one implicit entry for the yield itself, so length 1 means
        # "no user variables" — confirm against where YieldNode is built.
        self.reassignment_nodes = [
            node
            for node in self.reassignment_nodes
            if not (
                isinstance(node, YieldNode)
                and len(node.right_hand_side_variables) == 1
            )
        ]

    def _short_paths(self):
        """Return the base file names of the source and sink paths."""
        # os.path.basename matches split("/")[-1] on POSIX and also handles
        # the native separator on Windows.
        return os.path.basename(self.source.path), os.path.basename(self.sink.path)

    def __str__(self):
        """Pretty printing of a vulnerability."""
        source_path, sink_path = self._short_paths()
        return (
            "File: {}\n"
            ' > User input at line {}, source "{}":\n'
            "\t {}\nFile: {}\n"
            ' > reaches line {}, sink "{}":\n'
            "\t{}".format(
                source_path,
                self.source.line_number,
                self.source_trigger_word,
                self.source.label,
                sink_path,
                self.sink.line_number,
                self.sink_trigger_word,
                self.sink.label,
            )
        )

    def as_dict(self):
        """Serialise the vulnerability into a plain dict for reporting.

        Rule metadata (id, name, severity, OWASP/CWE categories and the
        description template) comes from ``source_sink_rules`` under the key
        ``"<source_type>:<sink_type>"``. When no rule matches, generic
        defaults are used (rule_id = source_type, severity "HIGH").
        """
        source_path, sink_path = self._short_paths()
        source_location = f"{source_path}:{self.source.line_number}"
        sink_location = f"{sink_path}:{self.sink.line_number}"

        rule_id = self.source_type
        rule_name = f"Data flow from {self.source_type} to {self.sink_type}"
        description = f"User controlled data flow from the source `{source_location}` to the sink `{sink_location}`"
        rule = source_sink_rules.get(f"{self.source_type}:{self.sink_type}")
        severity = "HIGH"
        owasp_category = ""
        cwe_category = ""
        if rule:
            rule_id = rule.code
            rule_name = rule.name
            severity = rule.severity
            owasp_category = rule.owasp_category
            cwe_category = rule.cwe_category
            sources = source_location
            if self.source_type == "Framework_Parameter":
                # Name the offending parameter, not just where it appears.
                sources = f"{self.source.label} in {source_location}"
            # Fill the rule's message template placeholders.
            description = rule.message_format.replace("{$sources}", sources).replace(
                "{$sinks}", sink_location
            )

        return {
            "rule_id": rule_id,
            "rule_name": rule_name,
            "severity": severity,
            "cwe_category": cwe_category,
            "owasp_category": owasp_category,
            "source": self.source.as_dict(),
            "source_trigger_word": self.source_trigger_word,
            "source_label": self.source.label,
            "source_type": self.source_type,
            "sink": self.sink.as_dict(),
            "sink_trigger_word": self.sink_trigger_word,
            "sink_label": self.sink.label,
            "sink_type": self.sink_type,
            "type": self.__class__.__name__,
            "reassignment_nodes": [node.as_dict() for node in self.reassignment_nodes],
            "description": description,
            "short_description": description,
        }


class SanitisedVulnerability(Vulnerability):
    """A Vulnerability whose taint path passes through a sanitiser.

    Args:
        confident: truthy when sanitisation is certain; when falsy the text
            output says "potentially sanitised".
        sanitiser: the sanitising object; printed with ``str()`` and
            serialised through its ``as_dict()``.
        **kwargs: forwarded unchanged to ``Vulnerability.__init__``.
    """

    def __init__(self, confident, sanitiser, **kwargs):
        super().__init__(**kwargs)
        self.confident = confident
        self.sanitiser = sanitiser

    def __str__(self):
        """Pretty printing of a vulnerability."""
        base_text = super().__str__()
        qualifier = "" if self.confident else "potentially "
        return "".join(
            (
                base_text,
                "\nThis vulnerability is ",
                qualifier,
                "sanitised by: ",
                str(self.sanitiser),
            )
        )

    def as_dict(self):
        """Extend the base dict with ``sanitiser`` and ``confident`` entries."""
        serialised = super().as_dict()
        serialised.update(
            sanitiser=self.sanitiser.as_dict(),
            confident=self.confident,
        )
        return serialised


class UnknownVulnerability(Vulnerability):
    """A Vulnerability whose status could not be decided.

    Args:
        unknown_assignment: the object the undecided status is attributed to;
            printed with ``str()`` and serialised through its ``as_dict()``.
        **kwargs: forwarded unchanged to ``Vulnerability.__init__``.
    """

    def __init__(self, unknown_assignment, **kwargs):
        super().__init__(**kwargs)
        self.unknown_assignment = unknown_assignment

    def as_dict(self):
        """Extend the base dict with an ``unknown_assignment`` entry."""
        serialised = super().as_dict()
        serialised.update(unknown_assignment=self.unknown_assignment.as_dict())
        return serialised

    def __str__(self):
        """Pretty printing of a vulnerability."""
        base_text = super().__str__()
        return f"{base_text}\nThis vulnerability is unknown due to: {self.unknown_assignment!s}"


# Record pairing a sanitiser trigger word with a CFG node.
Sanitiser = namedtuple("Sanitiser", "trigger_word cfg_node")


# Record bundling source triggers, sink triggers and a sanitiser mapping.
Triggers = namedtuple("Triggers", "sources sinks sanitiser_dict")


class TriggerNode:
    """Pair a matched trigger definition with the CFG node it matched.

    Secondary CFG nodes can be attached afterwards via :meth:`append`.
    """

    def __init__(self, trigger, cfg_node, secondary_nodes=None):
        """Create a trigger node.

        Args:
            trigger: the matched trigger definition. It must expose
                ``trigger_word`` and may expose ``source_type``, ``sink_type``
                and ``sanitisers``.
            cfg_node: the CFG node the trigger matched.
            secondary_nodes: optional initial list of secondary CFG nodes.
                Defaults to a fresh empty list per instance. The previous
                ``[]`` default was a single list shared by every TriggerNode.
        """
        self.trigger = trigger
        self.cfg_node = cfg_node
        self.secondary_nodes = [] if secondary_nodes is None else secondary_nodes

    @property
    def trigger_word(self):
        return self.trigger.trigger_word

    @property
    def source_type(self):
        """The trigger's ``source_type``, or ``None`` if it has none."""
        return getattr(self.trigger, "source_type", None)

    @property
    def sink_type(self):
        """The trigger's ``sink_type``, or ``None`` if it has none."""
        return getattr(self.trigger, "sink_type", None)

    @property
    def sanitisers(self):
        """The trigger's ``sanitisers``, or an empty list if it has none."""
        return getattr(self.trigger, "sanitisers", [])

    def append(self, cfg_node):
        """Record ``cfg_node`` as a secondary node.

        The primary ``cfg_node`` and nodes already recorded are ignored.
        """
        if cfg_node == self.cfg_node:
            return
        if not self.secondary_nodes:
            # Rebind rather than mutate, leaving a caller-supplied empty list untouched.
            self.secondary_nodes = [cfg_node]
        elif cfg_node not in self.secondary_nodes:
            self.secondary_nodes.append(cfg_node)

    def __repr__(self):
        output = "TriggerNode("

        if self.trigger_word:
            output = f"{output} trigger_word is {self.trigger_word}, \n"
        if self.source_type:
            output += f"source type {self.source_type}\n"
        if self.sink_type:
            output += f"sink type {self.sink_type}\n"
        return (
            output
            + f"sanitisers are {self.sanitisers}, "
            + f"cfg_node is {self.cfg_node})\n"
        )
